---
title: Оптимизация умножения матриц
description: Рассмотрим алгоритм перемножения матриц с использованием трёх вложенных циклов. Сложность такого алгоритма по определению должна составлять O(n³), но есть...
sections: [Перестановки,Вложенные циклы,Сравнение алгоритмов]
tags: [java,массивы,многомерные массивы,матрицы,строки,колонки,слои,циклы]
canonical_url: /ru/2021/12/09/optimizing-matrix-multiplication.html
url_translated: /en/2021/12/10/optimizing-matrix-multiplication.html
title_translated: Optimizing matrix multiplication
date: 2021.12.09
---

Рассмотрим алгоритм перемножения матриц с использованием трёх вложенных циклов. Сложность такого
алгоритма по определению должна составлять `O(n³)`, но есть особенности, связанные со средой 
выполнения — скорость работы алгоритма зависит от последовательности, в которой выполняются циклы.

Сравним различные варианты перестановок вложенных циклов и время выполнения алгоритмов.
Возьмём две матрицы: {`L×M`} и {`M×N`} &rarr; три цикла &rarr; шесть перестановок:
`LMN`, `LNM`, `MLN`, `MNL`, `NLM`, `NML`.

Быстрее других отрабатывают те алгоритмы, которые пишут данные в результирующую матрицу *построчно слоями*:
`LMN` и `MLN`, — разница в процентах к другим алгоритмам значительная и зависит от среды выполнения.

*Дальнейшая оптимизация: [Умножение матриц в параллельных потоках]({{ '/ru/2022/02/08/matrix-multiplication-parallel-streams.html' | relative_url }}).*

## Построчный алгоритм {#row-wise-algorithm}

Внешний цикл обходит строки первой матрицы `L`, далее идёт цикл по *общей стороне* двух матриц `M`
и за ним цикл по колонкам второй матрицы `N`. Запись в результирующую матрицу происходит построчно,
а каждая строка заполняется слоями.

```java
/**
 * @param l строки матрицы 'a'
 * @param m колонки матрицы 'a'
 *          и строки матрицы 'b'
 * @param n колонки матрицы 'b'
 * @param a первая матрица 'l×m'
 * @param b вторая матрица 'm×n'
 * @return результирующая матрица 'l×n'
 */
public static int[][] matrixMultiplicationLMN(int l, int m, int n, int[][] a, int[][] b) {
    // результирующая матрица
    int[][] c = new int[l][n];
    // обходим индексы строк матрицы 'a'
    for (int i = 0; i < l; i++)
        // обходим индексы общей стороны двух матриц:
        // колонок матрицы 'a' и строк матрицы 'b'
        for (int k = 0; k < m; k++)
            // обходим индексы колонок матрицы 'b'
            for (int j = 0; j < n; j++)
                // сумма произведений элементов i-ой строки
                // матрицы 'a' и j-ой колонки матрицы 'b'
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```

## Послойный алгоритм {#layer-wise-algorithm}

Внешний цикл обходит *общую сторону* двух матриц `M`, далее идёт цикл по строкам первой матрицы
`L` и за ним цикл по колонкам второй матрицы `N`. Запись в результирующую матрицу происходит слоями,
а каждый слой заполняется построчно.

```java
/**
 * @param l строки матрицы 'a'
 * @param m колонки матрицы 'a'
 *          и строки матрицы 'b'
 * @param n колонки матрицы 'b'
 * @param a первая матрица 'l×m'
 * @param b вторая матрица 'm×n'
 * @return результирующая матрица 'l×n'
 */
public static int[][] matrixMultiplicationMLN(int l, int m, int n, int[][] a, int[][] b) {
    // результирующая матрица
    int[][] c = new int[l][n];
    // обходим индексы общей стороны двух матриц:
    // колонок матрицы 'a' и строк матрицы 'b'
    for (int k = 0; k < m; k++)
        // обходим индексы строк матрицы 'a'
        for (int i = 0; i < l; i++)
            // обходим индексы колонок матрицы 'b'
            for (int j = 0; j < n; j++)
                // сумма произведений элементов i-ой строки
                // матрицы 'a' и j-ой колонки матрицы 'b'
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```

### Прочие алгоритмы {#other-algorithms}

Обход колонок второй матрицы `N` происходит перед обходом *общей стороны* двух матриц `M`
и/или перед обходом строк первой матрицы `L`.

{% capture collapsed_md %}
```java
public static int[][] matrixMultiplicationLNM(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int i = 0; i < l; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
```java
public static int[][] matrixMultiplicationNLM(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int i = 0; i < l; i++)
            for (int k = 0; k < m; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
```java
public static int[][] matrixMultiplicationMNL(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int k = 0; k < m; k++)
        for (int j = 0; j < n; j++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
```java
public static int[][] matrixMultiplicationNML(int l, int m, int n, int[][] a, int[][] b) {
    int[][] c = new int[l][n];
    for (int j = 0; j < n; j++)
        for (int k = 0; k < m; k++)
            for (int i = 0; i < l; i++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Код без комментариев" content=collapsed_md -%}

## Сравнение алгоритмов {#comparing-algorithms}

Для проверки возьмём две матрицы `A=[500×700]` и `B=[700×450]`, заполненные случайными числами.
Сначала сравниваем между собой корректность реализации алгоритмов — все полученные результаты
должны совпадать. Затем выполняем каждый метод по 10 раз и подсчитываем среднее время выполнения.

```java
// запускаем программу и выводим результат
public static void main(String[] args) throws Exception {
    // входящие данные
    int l = 500, m = 700, n = 450, steps = 10;
    int[][] a = randomMatrix(l, m), b = randomMatrix(m, n);
    // карта методов для сравнения
    var methods = new TreeMap<String, Callable<int[][]>>(Map.of(
            "LMN", () -> matrixMultiplicationLMN(l, m, n, a, b),
            "LNM", () -> matrixMultiplicationLNM(l, m, n, a, b),
            "MLN", () -> matrixMultiplicationMLN(l, m, n, a, b),
            "MNL", () -> matrixMultiplicationMNL(l, m, n, a, b),
            "NLM", () -> matrixMultiplicationNLM(l, m, n, a, b),
            "NML", () -> matrixMultiplicationNML(l, m, n, a, b)));
    int[][] last = null;
    // обходим карту методов, проверяем корректность результатов,
    // все полученные результаты должны быть равны друг другу
    for (var method : methods.entrySet()) {
        // следующий метод для сравнения
        var next = methods.higherEntry(method.getKey());
        // если текущий метод не последний — сравниваем результаты двух методов
        if (next != null) System.out.println(method.getKey() + "=" + next.getKey() + ": "
                // сравниваем результат выполнения текущего метода и следующего за ним
                + Arrays.deepEquals(method.getValue().call(), next.getValue().call()));
            // результат выполнения последнего метода
        else last = method.getValue().call();
    }
    int[][] test = last;
    // обходим карту методов, замеряем время работы каждого метода
    for (var method : methods.entrySet())
        // параметры: заголовок, количество шагов, исполняемый код 
        benchmark(method.getKey(), steps, () -> {
            try { // выполняем метод, получаем результат
                int[][] result = method.getValue().call();
                // проверяем корректность результатов на каждом шаге
                if (!Arrays.deepEquals(result, test)) System.out.print("error");
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
}
```

{% capture collapsed_md %}
```java
// вспомогательный метод, возвращает матрицу указанного размера
private static int[][] randomMatrix(int row, int col) {
    int[][] matrix = new int[row][col];
    for (int i = 0; i < row; i++)
        for (int j = 0; j < col; j++)
            matrix[i][j] = (int) (Math.random() * row * col);
    return matrix;
}
```
```java
// вспомогательный метод для замера времени работы переданного кода
private static void benchmark(String title, int steps, Runnable runnable) {
    long time, avg = 0;
    System.out.print(title);
    for (int i = 0; i < steps; i++) {
        time = System.currentTimeMillis();
        runnable.run();
        time = System.currentTimeMillis() - time;
        // время выполнения одного шага
        System.out.print(" | " + time);
        avg += time;
    }
    // среднее время выполнения
    System.out.println(" || " + (avg / steps));
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Вспомогательные методы" content=collapsed_md -%}

Вывод зависит от среды выполнения, время в миллисекундах:

```
LMN=LNM: true
LNM=MLN: true
MLN=MNL: true
MNL=NLM: true
NLM=NML: true
LMN | 191 | 109 | 105 | 106 | 105 | 106 | 106 | 105 | 123 | 109 || 116
LNM | 417 | 418 | 419 | 416 | 416 | 417 | 418 | 417 | 416 | 417 || 417
MLN | 113 | 115 | 113 | 115 | 114 | 114 | 114 | 115 | 114 | 113 || 114
MNL | 857 | 864 | 857 | 859 | 860 | 863 | 862 | 860 | 858 | 860 || 860
NLM | 404 | 404 | 407 | 404 | 406 | 405 | 405 | 404 | 403 | 404 || 404
NML | 866 | 872 | 867 | 868 | 867 | 868 | 867 | 873 | 869 | 863 || 868
```

Все описанные выше методы, включая свёрнутые блоки, можно поместить в одном классе.

{% capture collapsed_md %}
```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;
```
{% endcapture %}
{%- include collapsed_block.html summary="Необходимые импорты" content=collapsed_md -%}
